{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 构建语音识别系统 - 模型训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构建模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本次使用拟采用CTC的结构"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CTC的结构如下图所示：\n",
    "\n",
    "<img src=\"pic/CTCModel.png\" alt=\"CTC Model\" width=\"400\" />\n",
    "\n",
    "wavfrontend：进一步提取语音特征，做一下数据集的下采样\n",
    "\n",
    "encoder：对语音特征进行编码，得到编码后的特征向量\n",
    "\n",
    "softmax：将编码的特征向量通过线性层映射到词典大小，类似分类任务，预测出概率最大的词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from typing import Tuple, Union\n",
    "from model.subsampling import Subsampling\n",
    "import math\n",
    "from tokenizer.tokenizer import Tokenizer\n",
    "\n",
    "class CTCModel(nn.Module):\n",
    "    def __init__(self, in_dim, output_size,\n",
    "                 vocab_size, blank_id\n",
    "                ):\n",
    "        super().__init__()\n",
    "        self.subsampling = Subsampling(in_dim, output_size, subsampling_type=8)\n",
    "        encoder_layer = nn.TransformerEncoderLayer(d_model=output_size, nhead=8)\n",
    "        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)\n",
    "\n",
    "        self.fc_out = nn.Linear(output_size, vocab_size)\n",
    "\n",
    "        self.ctc_loss = torch.nn.CTCLoss(blank=blank_id,\n",
    "                                         reduction=\"sum\",\n",
    "                                         zero_infinity=True)\n",
    "\n",
    "    def forward(self, x, audio_lens, text, text_lens):\n",
    "        x, encoder_out_lens = self.subsampling(x, audio_lens)\n",
    "        x = self.encoder(x)\n",
    "        predict = self.fc_out(x)\n",
    "\n",
    "        predict = predict.transpose(0, 1)\n",
    "        predict = predict.log_softmax(2)\n",
    "        loss = self.ctc_loss(predict, text, encoder_out_lens, text_lens)\n",
    "        loss = loss / predict.size(1)\n",
    "        predict = predict.transpose(0, 1)\n",
    "        return predict, loss, encoder_out_lens\n",
    "    \n",
    "    def inference(self, x, audio_lens):\n",
    "        x, encoder_out_lens = self.subsampling(x, audio_lens)\n",
    "        x = self.encoder(x)\n",
    "        predict = self.fc_out(x)\n",
    "        predict = predict.log_softmax(2)\n",
    "\n",
    "        return predict, encoder_out_lens\n",
    "\n",
    "tokenizer = Tokenizer(\"./tokenizer/vocab.txt\")\n",
    "model = CTCModel(80, 256, tokenizer.size(), tokenizer.blk_id())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上面的模型是采用了 Transformer Encoder 作为声学模型，loss采用CTC loss。位于`model.model.CTCModel`。这里模型帮大家写好了。\n",
    "\n",
    "这里简单说明一下CTC：在语音识别的时候，假设神经网络的输出长度为$L_{audio}$，识别文本经过tokenzier分词后，长度为$L_{text}$。一般情况下，$L_{audio} > L_{text}$\n",
    "\n",
    "参考下面的代码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "input = torch.load(\"./example1.pt\")\n",
    "audios = input[\"audios\"]\n",
    "audio_lens = input[\"audio_lens\"]\n",
    "texts = input[\"texts\"]\n",
    "text_lens = input[\"text_lens\"]\n",
    "\n",
    "pre, loss, pre_lens = model(audios, audio_lens, texts, text_lens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[64, 41, 75, 60, 56, 72, 19, 70, 59, 47, 41, 32, 80, 31, 26, 65]\n",
      "[24, 14, 27, 21, 20, 22, 6, 27, 20, 17, 14, 12, 26, 12, 9, 25]\n"
     ]
    }
   ],
   "source": [
    "print(pre_lens.view(-1).tolist())\n",
    "print(text_lens.view(-1).tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CTC的思路：\n",
    "\n",
    "<img src=\"pic/CTC.png\" alt=\"CTC\" width=\"800\" />"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CTC引入了一个特殊字符，例如图中的$\\epsilon$。对应我们tokenizer的special tokens里面的 \\<blk\\>。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['<pad>', '<unk>', '<sos>', '<eos>', ' ', '<blk>']\n"
     ]
    }
   ],
   "source": [
    "print(tokenizer.special_token)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CTC对输入的每一个时间步进行预测，根据对应的输出结果进行两步的处理：\n",
    "\n",
    "- 将重复的字符压缩为一个\n",
    "- 移除掉 $\\epsilon$\n",
    "\n",
    "例如上图中的一个解码路径： [$\\epsilon$, a, $\\epsilon$, $\\epsilon$, b, $\\epsilon$]\n",
    "\n",
    "- 经过第一个步骤， [$\\epsilon$, a, $\\epsilon$, b, $\\epsilon$]\n",
    "- 经过第二个步骤，[a, b]\n",
    "\n",
    "要得到[a,b]有许多情况，在训练的时候，目标是把这些路径的概率最大。\n",
    "\n",
    "更加详细的参考[CTC](https://distill.pub/2017/ctc)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练框架例如了 `run.py`里面的，大家可以根据自己的需求修改一下。\n",
    "\n",
    "**TODO:** 把训练时候的loss记录下来，并作出曲线。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
